{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Kso-Y7_7Y6Cm"
      },
      "source": [
        "# Configuring Sonnet's BatchNorm Module\n",
        "\n",
        "This colab walks through the different modes of operation of Sonnet's BatchNorm module.\n",
        "\n",
        "The module's behaviour is determined by three main parameters: one constructor argument (```update_ops_collection```) and two arguments that are passed when the module is connected to the graph (```is_training``` and ```test_local_stats```).\n",
        "\n",
        "```python\n",
        "import sonnet as snt\n",
        "\n",
        "# Pass the arguments by keyword so the call does not rely on their position.\n",
        "bn = snt.BatchNorm(update_ops_collection=update_ops_collection)\n",
        "outputs = bn(inputs, is_training=is_training, test_local_stats=test_local_stats)\n",
        "```\n",
        "\n",
        "The following diagram visualizes how different parameter settings lead to different modes of operation. Bold arrows mark the current default values of the arguments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "cellView": "form",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "output_extras": [
            {
              "item_id": 1
            }
          ]
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 145,
          "status": "ok",
          "timestamp": 1499163928520,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "7bCTwDfsbPXj",
        "outputId": "a5c7a656-746e-4253-ded0-90e7ff36a67f"
      },
      "outputs": [
        {
          "data": {
            "image/svg+xml": "\u003csvg content=\"\u0026lt;mxfile userAgent=\u0026quot;Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36\u0026quot; version=\u0026quot;6.7.3\u0026quot; editor=\u0026quot;www.draw.io\u0026quot;\u0026gt;\u0026lt;diagram name=\u0026quot;Page-1\u0026quot;\u0026gt;7Vptc5s4EP41fLQHkMHxx9pNrh96mZtJb3r3ySODDLoKxAnhl/76roR4C7bHl5CEXOuZTORHy2q1++yuZNtCq+Twm8BZ/DsPCbNcOzxY6KPlwstz4Z9CjiUyd+wSiAQNS8hpgAf6nRiwEitoSPKOoOScSZp1wYCnKQlkB8NC8H1XbMtZd9UMR6QHPASY9dGvNJRxid64foN/IjSKq5Udf1HObHDwLRK8SM16lou2+lVOJ7jSZTaaxzjk+xaEbi20EpzLcpQcVoQp31ZuK5+7OzNb2y1IKq95wDNmyGO1dW08UdO2hZaxTBgMHRiSA5V/KXjqmXd/m5l/iJRHE0RcSA4QFzLmEU8x+8x5VmkIwclmpUbgtkGXQSF2em0lnkvBv9X+R4CUxio1Z7droJwXIjBShosSi4gYKVS7GihMeEKkOIKIIAxLuutqx4ZLUS3X+BMGxqWn3TsvVewwK4zSLwJGj30uyUF2nS1ITr/jjRZQccg4TaU2xFta3kdAMKNRCkAAWycCgB0RkgKFP5iJhIahdirDG8KWNTFXnHGh162oiZZbnso7nFCm8nUFvqOg0rXvyd5MVmEzPlFrkcOpVDQ2NwxvB8c77XajaGJPnTlC5WPHjvKrA2OU/6Hc1RLh220OsX8cudqGq4Lp9+L2Hwn9gqk1WGbMXikzbnqZcYdZ/t5S41nZ4F/MBkgGuwrQKOjv9iJmuT5TwckExM2P1JDmaykwTWkaVbOgtyXQzyHGoKurYOxjKslDhjU793Cw6Mb9rK97JD/rUxfZnepSRWXf9HjHNljc6u+e/XzCO05/66MpH0O1WtQvKI59OiSDVxSnT9AHKTQR31VNGa7dOpezAfqt20mHyZjKjYNGnC+DZcfitZJj1kuOe57+VO3WOXPor/uta89nnXwY0+ETXdN9KyiAC3mFFVmIJVnzLF8HnDHwEuVpqze3ZUfYsmezbst25id69uKlevZ8xDWo27OfXpVmJ3r2mYva8GVp0aP1T3cNKEl28VaMZotOEoypT9efZI0/R55+rj2VI2eub4PnSKXj14dIbcJdujePKDv6x64r27YkuVwzDqFY5xJD1N5Nx57M37Bj9/tJ5aGCPfYZox0n+f8WvPT8QU5MYnwACUa2spmtlNxzkYAQVCDX5oXMCkWdItfXTVvHzVLfBsgg1osAE3NIrHYY1fIDGvSnPufBdMJ3pRUYwogjkiv3pKroYW3Mlos9FqGOaX7ZIADbbntEqqZuO69CLM/vEgvZfWLV3yC1ieUPcRS0fzEL7hAlf1QrjM9xDas6ZGcMmBBW1MPwl+JEA80tZGqpWl5auKlW+xKTY60l5Wrz5ECCQuqHVbNPsG5R7Dht7XlzgrtjJzSyb55G6PpDymcx2hsFo3PJhQ5tj0vTl6P0PVcWdHldMi4QBMDw8trjopHnvWFZ9EdBorcoi/8rDk1c+/WaK7xtfuFQHqqbn5Gg2x8=\u0026lt;/diagram\u0026gt;\u0026lt;/mxfile\u0026gt;\" height=\"384px\" style=\"background-color: rgb(255, 255, 255);\" version=\"1.1\" width=\"971px\" 
xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\"\u003e\n  \u003cdefs/\u003e\n  \u003cg transform=\"translate(0.5,0.5)\"\u003e\n    \u003cpath d=\"M 480 50 Q 480 100 607.5 100 Q 735 100 735 139.9\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"1\"/\u003e\n    \u003cpath d=\"M 735 146.65 L 730.5 137.65 L 735 139.9 L 739.5 137.65 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"1\"/\u003e\n    \u003cg transform=\"translate(562.5,93.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"12\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"29\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eTrue\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" text-anchor=\"middle\" x=\"15\" y=\"12\"\u003eTrue\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cpath d=\"M 480 50 Q 480 100 352.5 100 Q 225 100 225 143.63\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cpath d=\"M 225 148.88 L 221.5 141.88 L 225 143.63 L 228.5 141.88 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cg transform=\"translate(316.5,94.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"12\" 
pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"36\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eFalse\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" text-anchor=\"middle\" x=\"18\" y=\"12\"\u003eFalse\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cellipse cx=\"480\" cy=\"25\" fill=\"#ffffff\" pointer-events=\"none\" rx=\"50\" ry=\"25\" stroke=\"#000000\"/\u003e\n    \u003cg transform=\"translate(439.5,6.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"36\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"80\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 80px; white-space: nowrap; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cpre\u003eis_training\u003c/pre\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"40\" 
y=\"24\"\u003e\u0026lt;pre\u0026gt;is_training\u0026lt;/pre\u0026gt;\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cpath d=\"M 735 200 Q 735 240 674 240 Q 613 240 613 269.9\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"3\"/\u003e\n    \u003cpath d=\"M 613 276.65 L 608.5 267.65 L 613 269.9 L 617.5 267.65 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"3\"/\u003e\n    \u003cg transform=\"translate(672.5,232.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"12\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"43\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; font-weight: bold; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eString\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" font-weight=\"bold\" text-anchor=\"middle\" x=\"22\" y=\"12\"\u003eString\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cpath d=\"M 735 200 Q 735 240 800 240 Q 865 240 865 273.63\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cpath d=\"M 865 278.88 L 861.5 271.88 L 865 273.63 L 868.5 271.88 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cg transform=\"translate(807.5,233.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"12\" 
pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"29\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eNone\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" text-anchor=\"middle\" x=\"15\" y=\"12\"\u003eNone\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cellipse cx=\"735\" cy=\"175\" fill=\"#ffffff\" pointer-events=\"none\" rx=\"95\" ry=\"25\" stroke=\"#000000\"/\u003e\n    \u003cg transform=\"translate(658.5,156.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"36\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"152\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 152px; white-space: nowrap; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cpre\u003eupdate_ops_collection\u003c/pre\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"76\" 
y=\"24\"\u003e\u0026lt;pre\u0026gt;\u0026lt;code\u0026gt;update_ops_collection\u0026lt;/code\u0026gt;\u0026lt;/pre\u0026gt;\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cpath d=\"M 225 200 Q 225 240 292.5 240 Q 360 240 360 273.63\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cpath d=\"M 360 278.88 L 356.5 271.88 L 360 273.63 L 363.5 271.88 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\"/\u003e\n    \u003cg transform=\"translate(260.5,232.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"12\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"36\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eFalse\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" text-anchor=\"middle\" x=\"18\" y=\"12\"\u003eFalse\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cpath d=\"M 225 200 Q 225 240 165 240 Q 105 240 105 269.9\" fill=\"none\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"3\"/\u003e\n    \u003cpath d=\"M 105 276.65 L 100.5 267.65 L 105 269.9 L 109.5 267.65 Z\" fill=\"#000000\" pointer-events=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" stroke-width=\"3\"/\u003e\n    \u003cg transform=\"translate(140.5,234.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject 
height=\"12\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"29\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; font-weight: bold; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003eTrue\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Courier New\" font-size=\"12px\" font-weight=\"bold\" text-anchor=\"middle\" x=\"15\" y=\"12\"\u003eTrue\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003cellipse cx=\"225\" cy=\"175\" fill=\"#ffffff\" pointer-events=\"none\" rx=\"95\" ry=\"25\" stroke=\"#000000\"/\u003e\n    \u003cg transform=\"translate(166.5,156.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"36\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"116\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 116px; white-space: nowrap; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cpre\u003etest_local_stats\u003c/pre\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" 
x=\"58\" y=\"24\"\u003e\u0026lt;pre\u0026gt;\u0026lt;code\u0026gt;test_local_stats\u0026lt;/code\u0026gt;\u0026lt;/pre\u0026gt;\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003crect fill=\"#ffffff\" height=\"60\" pointer-events=\"none\" rx=\"9\" ry=\"9\" stroke=\"#000000\" width=\"210\" x=\"760\" y=\"280\"/\u003e\n    \u003cg transform=\"translate(761.5,270.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"78\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"206\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cul\u003e\n                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n                \u003cli style=\"text-align: left\"\u003eUpdate moving averages in each forward pass\u003c/li\u003e\n              \u003c/ul\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"103\" y=\"45\"\u003e[Not supported by viewer]\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003crect fill=\"#ffffff\" height=\"100\" pointer-events=\"none\" rx=\"15\" ry=\"15\" stroke=\"#000000\" width=\"210\" x=\"508\" y=\"280\"/\u003e\n    \u003cg transform=\"translate(509.5,276.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"106\" pointer-events=\"all\" 
requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"206\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cul\u003e\n                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n                \u003cli style=\"text-align: left\"\u003eUpdate ops for the moving averages are placed in a named collection.\n                  \u003cb\u003eThey are not executed automatically.\u003c/b\u003e\n                \u003c/li\u003e\n              \u003c/ul\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"103\" y=\"59\"\u003e[Not supported by viewer]\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003crect fill=\"#ffffff\" height=\"60\" pointer-events=\"none\" rx=\"9\" ry=\"9\" stroke=\"#000000\" width=\"210\" x=\"255\" y=\"280\"/\u003e\n    \u003cg transform=\"translate(256.5,277.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"64\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"206\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv 
style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cul\u003e\n                \u003cli style=\"text-align: left\"\u003eNormalize output using stored moving averages.\u003c/li\u003e\n                \u003cli style=\"text-align: left\"\u003eNo update ops are created.\u003c/li\u003e\n              \u003c/ul\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" text-anchor=\"middle\" x=\"103\" y=\"38\"\u003e[Not supported by viewer]\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n    \u003crect fill=\"#ffffff\" height=\"60\" pointer-events=\"none\" rx=\"9\" ry=\"9\" stroke=\"#000000\" width=\"210\" x=\"0\" y=\"280\"/\u003e\n    \u003cg transform=\"translate(1.5,277.5)\"\u003e\n      \u003cswitch\u003e\n        \u003cforeignObject height=\"64\" pointer-events=\"all\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\" style=\"overflow:visible;\" width=\"206\"\u003e\n          \u003cdiv style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n            \u003cdiv style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\" xmlns=\"http://www.w3.org/1999/xhtml\"\u003e\n              \u003cul\u003e\n                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n                \u003cli style=\"text-align: left\"\u003eNo update ops are created.\u003c/li\u003e\n              \u003c/ul\u003e\n            \u003c/div\u003e\n          \u003c/div\u003e\n        \u003c/foreignObject\u003e\n        \u003ctext fill=\"#000000\" font-family=\"Helvetica\" font-size=\"12px\" 
text-anchor=\"middle\" x=\"103\" y=\"38\"\u003e[Not supported by viewer]\u003c/text\u003e\n      \u003c/switch\u003e\n    \u003c/g\u003e\n  \u003c/g\u003e\n\u003c/svg\u003e",
            "text/plain": [
              "\u003cIPython.core.display.SVG at 0xc9aed90\u003e"
            ]
          },
          "metadata": {
            "tags": []
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "#@title Decision tree\n",
        "%%svg\n",
        "\u003c!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\"\u003e\n",
        "\u003csvg\n",
        "  xmlns=\"http://www.w3.org/2000/svg\"\n",
        "  xmlns:xlink=\"http://www.w3.org/1999/xlink\" width=\"971px\" height=\"384px\" version=\"1.1\" content=\"\u0026lt;mxfile userAgent=\u0026quot;Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36\u0026quot; version=\u0026quot;6.7.3\u0026quot; editor=\u0026quot;www.draw.io\u0026quot;\u0026gt;\u0026lt;diagram name=\u0026quot;Page-1\u0026quot;\u0026gt;7Vptc5s4EP41fLQHkMHxx9pNrh96mZtJb3r3ySODDLoKxAnhl/76roR4C7bHl5CEXOuZTORHy2q1++yuZNtCq+Twm8BZ/DsPCbNcOzxY6KPlwstz4Z9CjiUyd+wSiAQNS8hpgAf6nRiwEitoSPKOoOScSZp1wYCnKQlkB8NC8H1XbMtZd9UMR6QHPASY9dGvNJRxid64foN/IjSKq5Udf1HObHDwLRK8SM16lou2+lVOJ7jSZTaaxzjk+xaEbi20EpzLcpQcVoQp31ZuK5+7OzNb2y1IKq95wDNmyGO1dW08UdO2hZaxTBgMHRiSA5V/KXjqmXd/m5l/iJRHE0RcSA4QFzLmEU8x+8x5VmkIwclmpUbgtkGXQSF2em0lnkvBv9X+R4CUxio1Z7droJwXIjBShosSi4gYKVS7GihMeEKkOIKIIAxLuutqx4ZLUS3X+BMGxqWn3TsvVewwK4zSLwJGj30uyUF2nS1ITr/jjRZQccg4TaU2xFta3kdAMKNRCkAAWycCgB0RkgKFP5iJhIahdirDG8KWNTFXnHGh162oiZZbnso7nFCm8nUFvqOg0rXvyd5MVmEzPlFrkcOpVDQ2NwxvB8c77XajaGJPnTlC5WPHjvKrA2OU/6Hc1RLh220OsX8cudqGq4Lp9+L2Hwn9gqk1WGbMXikzbnqZcYdZ/t5S41nZ4F/MBkgGuwrQKOjv9iJmuT5TwckExM2P1JDmaykwTWkaVbOgtyXQzyHGoKurYOxjKslDhjU793Cw6Mb9rK97JD/rUxfZnepSRWXf9HjHNljc6u+e/XzCO05/66MpH0O1WtQvKI59OiSDVxSnT9AHKTQR31VNGa7dOpezAfqt20mHyZjKjYNGnC+DZcfitZJj1kuOe57+VO3WOXPor/uta89nnXwY0+ETXdN9KyiAC3mFFVmIJVnzLF8HnDHwEuVpqze3ZUfYsmezbst25id69uKlevZ8xDWo27OfXpVmJ3r2mYva8GVp0aP1T3cNKEl28VaMZotOEoypT9efZI0/R55+rj2VI2eub4PnSKXj14dIbcJdujePKDv6x64r27YkuVwzDqFY5xJD1N5Nx57M37Bj9/tJ5aGCPfYZox0n+f8WvPT8QU5MYnwACUa2spmtlNxzkYAQVCDX5oXMCkWdItfXTVvHzVLfBsgg1osAE3NIrHYY1fIDGvSnPufBdMJ3pRUYwogjkiv3pKroYW3Mlos9FqGOaX7ZIADbbntEqqZuO69CLM/vEgvZfWLV3yC1ieUPcRS0fzEL7hAlf1QrjM9xDas6ZGcMmBBW1MPwl+JEA80tZGqpWl5auKlW+xKTY60l5Wrz5ECCQuqHVbNPsG5R7Dht7XlzgrtjJzSyb55G6PpDymcx2hsFo3PJhQ5tj0vTl6P0PVcWdHldMi4QBMDw8trjopHnvWFZ9EdBorcoi/8rDk1c+/WaK7xtfuFQHqqbn5Gg2x8=\u0026lt;/diagram\u0026gt;\u0026lt;/mxfile\u0026gt;\" style=\"background-color: rgb(255, 255, 255);\"\u003e\n",
        "  \u003cdefs/\u003e\n",
        "  \u003cg transform=\"translate(0.5,0.5)\"\u003e\n",
        "    \u003cpath d=\"M 480 50 Q 480 100 607.5 100 Q 735 100 735 139.9\" fill=\"none\" stroke=\"#000000\" stroke-width=\"1\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 735 146.65 L 730.5 137.65 L 735 139.9 L 739.5 137.65 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-width=\"1\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(562.5,93.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"29\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eTrue\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"15\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\"\u003eTrue\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cpath d=\"M 480 50 Q 480 100 352.5 100 Q 225 100 225 143.63\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 225 148.88 L 221.5 141.88 L 225 143.63 L 228.5 141.88 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(316.5,94.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"36\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eFalse\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"18\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\"\u003eFalse\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cellipse cx=\"480\" cy=\"25\" rx=\"50\" ry=\"25\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(439.5,6.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"80\" height=\"36\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 80px; white-space: nowrap; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cpre\u003eis_training\u003c/pre\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"40\" y=\"24\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e\u0026lt;pre\u0026gt;is_training\u0026lt;/pre\u0026gt;\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cpath d=\"M 735 200 Q 735 240 674 240 Q 613 240 613 269.9\" fill=\"none\" stroke=\"#000000\" stroke-width=\"3\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 613 276.65 L 608.5 267.65 L 613 269.9 L 617.5 267.65 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-width=\"3\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(672.5,232.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"43\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; font-weight: bold; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eString\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"22\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\" font-weight=\"bold\"\u003eString\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cpath d=\"M 735 200 Q 735 240 800 240 Q 865 240 865 273.63\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 865 278.88 L 861.5 271.88 L 865 273.63 L 868.5 271.88 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(807.5,233.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"29\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eNone\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"15\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\"\u003eNone\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cellipse cx=\"735\" cy=\"175\" rx=\"95\" ry=\"25\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(658.5,156.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"152\" height=\"36\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 152px; white-space: nowrap; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cpre\u003eupdate_ops_collection\u003c/pre\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"76\" y=\"24\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e\u0026lt;pre\u0026gt;\u0026lt;code\u0026gt;update_ops_collection\u0026lt;/code\u0026gt;\u0026lt;/pre\u0026gt;\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cpath d=\"M 225 200 Q 225 240 292.5 240 Q 360 240 360 273.63\" fill=\"none\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 360 278.88 L 356.5 271.88 L 360 273.63 L 363.5 271.88 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(260.5,232.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"36\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eFalse\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"18\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\"\u003eFalse\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cpath d=\"M 225 200 Q 225 240 165 240 Q 105 240 105 269.9\" fill=\"none\" stroke=\"#000000\" stroke-width=\"3\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cpath d=\"M 105 276.65 L 100.5 267.65 L 105 269.9 L 109.5 267.65 Z\" fill=\"#000000\" stroke=\"#000000\" stroke-width=\"3\" stroke-miterlimit=\"10\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(140.5,234.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"29\" height=\"12\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: \u0026quot;Courier New\u0026quot;; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; white-space: nowrap; font-weight: bold; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;background-color:#ffffff;\"\u003eTrue\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"15\" y=\"12\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Courier New\" font-weight=\"bold\"\u003eTrue\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003cellipse cx=\"225\" cy=\"175\" rx=\"95\" ry=\"25\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(166.5,156.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"116\" height=\"36\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 116px; white-space: nowrap; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cpre\u003etest_local_stats\u003c/pre\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"58\" y=\"24\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e\u0026lt;pre\u0026gt;\u0026lt;code\u0026gt;test_local_stats\u0026lt;/code\u0026gt;\u0026lt;/pre\u0026gt;\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003crect x=\"760\" y=\"280\" width=\"210\" height=\"60\" rx=\"9\" ry=\"9\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(761.5,270.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"206\" height=\"78\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cul\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eUpdate moving averages in each forward pass\u003c/li\u003e\n",
        "              \u003c/ul\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"103\" y=\"45\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e[Not supported by viewer]\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003crect x=\"508\" y=\"280\" width=\"210\" height=\"100\" rx=\"15\" ry=\"15\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(509.5,276.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"206\" height=\"106\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cul\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eUpdate ops for the moving averages are placed in a named collection.\n",
        "                  \u003cb\u003eThey are not executed automatically.\u003c/b\u003e\n",
        "                \u003c/li\u003e\n",
        "              \u003c/ul\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"103\" y=\"59\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e[Not supported by viewer]\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003crect x=\"255\" y=\"280\" width=\"210\" height=\"60\" rx=\"9\" ry=\"9\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(256.5,277.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"206\" height=\"64\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cul\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNormalize output using stored moving averages.\u003c/li\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNo update ops are created.\u003c/li\u003e\n",
        "              \u003c/ul\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"103\" y=\"38\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e[Not supported by viewer]\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "    \u003crect x=\"0\" y=\"280\" width=\"210\" height=\"60\" rx=\"9\" ry=\"9\" fill=\"#ffffff\" stroke=\"#000000\" pointer-events=\"none\"/\u003e\n",
        "    \u003cg transform=\"translate(1.5,277.5)\"\u003e\n",
        "      \u003cswitch\u003e\n",
        "        \u003cforeignObject style=\"overflow:visible;\" pointer-events=\"all\" width=\"206\" height=\"64\" requiredFeatures=\"http://www.w3.org/TR/SVG11/feature#Extensibility\"\u003e\n",
        "          \u003cdiv\n",
        "            xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display: inline-block; font-size: 12px; font-family: Helvetica; color: rgb(0, 0, 0); line-height: 1.2; vertical-align: top; width: 206px; white-space: normal; word-wrap: normal; text-align: center;\"\u003e\n",
        "            \u003cdiv\n",
        "              xmlns=\"http://www.w3.org/1999/xhtml\" style=\"display:inline-block;text-align:inherit;text-decoration:inherit;\"\u003e\n",
        "              \u003cul\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNormalize output using local batch statistics\u003c/li\u003e\n",
        "                \u003cli style=\"text-align: left\"\u003eNo update ops are created.\u003c/li\u003e\n",
        "              \u003c/ul\u003e\n",
        "            \u003c/div\u003e\n",
        "          \u003c/div\u003e\n",
        "        \u003c/foreignObject\u003e\n",
        "        \u003ctext x=\"103\" y=\"38\" fill=\"#000000\" text-anchor=\"middle\" font-size=\"12px\" font-family=\"Helvetica\"\u003e[Not supported by viewer]\u003c/text\u003e\n",
        "      \u003c/switch\u003e\n",
        "    \u003c/g\u003e\n",
        "  \u003c/g\u003e\n",
        "\u003c/svg\u003e\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Q5rFcT-gfy8c"
      },
      "outputs": [],
      "source": [
        "#@title Setup\n",
        "\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import sonnet as snt\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import patches\n",
        "%matplotlib inline\n",
        "\n",
        "def run_and_visualize(inputs, outputs, bn_module):\n",
        "  \"\"\"Runs 1000 forward passes, then prints and plots the statistics.\"\"\"\n",
        "  init = tf.global_variables_initializer()\n",
        "  with tf.Session() as sess:\n",
        "    sess.run(init)\n",
        "\n",
        "    inputs_collection = []\n",
        "    outputs_collection = []\n",
        "\n",
        "    for _ in range(1000):\n",
        "      current_inputs, current_outputs = sess.run([inputs, outputs])\n",
        "      inputs_collection.append(current_inputs)\n",
        "      outputs_collection.append(current_outputs)\n",
        "\n",
        "    bn_mean, bn_var = sess.run([bn_module._moving_mean,\n",
        "                                bn_module._moving_variance])\n",
        "\n",
        "  inputs_collection = np.concatenate(inputs_collection, axis=0)\n",
        "  outputs_collection = np.concatenate(outputs_collection, axis=0)\n",
        "\n",
        "  print(\"Number of update ops in collection: {}\".format(\n",
        "      len(tf.get_collection(tf.GraphKeys.UPDATE_OPS))))\n",
        "  print(\"Input mean: {}\".format(np.mean(inputs_collection, axis=0)))\n",
        "  print(\"Input variance: {}\".format(np.var(inputs_collection, axis=0)))\n",
        "  print(\"Moving mean: {}\".format(bn_mean))\n",
        "  print(\"Moving variance: {}\".format(bn_var))\n",
        "\n",
        "  plt.figure()\n",
        "  # Plot the aggregated statistics as an ellipse centered on the moving\n",
        "  # mean, with the moving variances as its width and height.\n",
        "  ellipse = patches.Ellipse(xy=bn_mean[0], width=bn_var[0, 0],\n",
        "                            height=bn_var[0, 1], angle=0, edgecolor='g',\n",
        "                            fc='None', zorder=1000, linestyle='solid',\n",
        "                            linewidth=2)\n",
        "  # Plot the input distribution.\n",
        "  input_ax = plt.scatter(inputs_collection[:, 0], inputs_collection[:, 1],\n",
        "                         c='r', alpha=0.1, zorder=1)\n",
        "  # Plot the output distribution.\n",
        "  output_ax = plt.scatter(outputs_collection[:, 0], outputs_collection[:, 1],\n",
        "                          c='b', alpha=0.1, zorder=1)\n",
        "  ax = plt.gca()\n",
        "  ellipse_ax = ax.add_patch(ellipse)\n",
        "  plt.legend((input_ax, output_ax, ellipse_ax),\n",
        "             (\"Inputs\", \"Outputs\", \"Aggregated statistics\"),\n",
        "             loc=\"lower right\")\n",
        "  plt.axis(\"equal\")\n",
        "\n",
        "def get_inputs():\n",
        "  \"\"\"Returns a batch of 10 points with mean (10, 10) and stddev (1, 2).\"\"\"\n",
        "  return tf.concat([\n",
        "      tf.random_normal((10, 1), 10, 1),\n",
        "      tf.random_normal((10, 1), 10, 2)],\n",
        "      axis=1)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "O5vp5YJd76Az"
      },
      "source": [
        "# Examples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xvOw3kozfpRz"
      },
      "source": [
        "\n",
        "## Default mode"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "csxsSTd1gVYQ"
      },
      "outputs": [],
      "source": [
        "tf.reset_default_graph()\n",
        "\n",
        "inputs = get_inputs()\n",
        "bn = snt.BatchNorm()\n",
        "outputs = bn(inputs, is_training=True)\n",
        "\n",
        "run_and_visualize(inputs, outputs, bn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "UK8urKxuJB91"
      },
      "source": [
        "**Results**\n",
        "\n",
        "1. The outputs have been normalized. This is indicated by the blue isotropic Gaussian distribution.\n",
        "1. Update ops have been created and placed in a collection, but nothing has executed them.\n",
        "1. No moving statistics have been collected. The green circle shows the aggregated statistics, which are initialized to mean 0 and standard deviation 1. Because the update ops were created but not executed, these statistics have not been updated.\n",
        "1. The \"boxy\" shape of the normalized data points comes from the small batch size of 10. Because the batch statistics are computed over only 10 data points, they are very noisy."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-JYUcQl0k7Jv"
      },
      "source": [
        "## Collecting statistics during training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oXtoqQI8lFpG"
      },
      "source": [
        "### First option: Update statistics automatically on every forward pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Xw7P8il1k6Zq"
      },
      "outputs": [],
      "source": [
        "tf.reset_default_graph()\n",
        "\n",
        "inputs = get_inputs()\n",
        "bn = snt.BatchNorm(update_ops_collection=None)\n",
        "outputs = bn(inputs, is_training=True)\n",
        "\n",
        "run_and_visualize(inputs, outputs, bn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "E2Nr-mbJJadJ"
      },
      "source": [
        "**Results**\n",
        "\n",
        "1. The outputs have been normalized, as we can tell from the blue isotropic Gaussian distribution.\n",
        "1. Update ops have been created and executed. We can see that the moving statistics no longer have their initial values (i.e. the green ellipse has changed). The aggregated statistics don't match the input distribution yet because the moving averages adapt slowly and we only ran 1000 forward passes."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zR1S6Wqild_e"
      },
      "source": [
        "### Second option: Explicitly add update ops as control dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "pg6f3NDYlU0_"
      },
      "outputs": [],
      "source": [
        "tf.reset_default_graph()\n",
        "\n",
        "inputs = get_inputs()\n",
        "# Place the update ops in a named collection instead of running them\n",
        "# automatically (update_ops_collection=None would run them implicitly).\n",
        "bn = snt.BatchNorm(update_ops_collection=tf.GraphKeys.UPDATE_OPS)\n",
        "outputs = bn(inputs, is_training=True)\n",
        "\n",
        "# Add the update ops as control dependencies of the outputs.\n",
        "# In a real model this is usually done when defining the gradient\n",
        "# descent ops.\n",
        "update_ops = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))\n",
        "with tf.control_dependencies([update_ops]):\n",
        "  outputs = tf.identity(outputs)\n",
        "\n",
        "run_and_visualize(inputs, outputs, bn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zK_eP34uKqkR"
      },
      "source": [
        "**Results**\n",
        "\n",
        "The results match the previous run. This time, however, the update ops are not executed automatically on each forward pass. We have to make the updates an explicit dependency of our output by using ```tf.control_dependencies```. In practice, we would usually attach these dependencies to our learning ops rather than to the output."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "30fmZgYen8QM"
      },
      "source": [
        "# Using statistics at test time"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "56_FJee2oD1g"
      },
      "source": [
        "## Default mode"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "UgD_D81RoAvH"
      },
      "outputs": [],
      "source": [
        "tf.reset_default_graph()\n",
        "\n",
        "inputs = get_inputs()\n",
        "bn = snt.BatchNorm()\n",
        "outputs = bn(inputs, is_training=False)\n",
        "\n",
        "run_and_visualize(inputs, outputs, bn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Wz_kDgxnLNnJ"
      },
      "source": [
        "**Results**\n",
        "\n",
        "1. No update ops have been created, and the moving statistics still have their initial values (mean 0, standard deviation 1).\n",
        "2. The inputs have been normalized using the batch statistics, as we can tell from the blue isotropic Gaussian distribution.\n",
        "\n",
        "In other words, in the default testing mode the inputs are normalized using the batch statistics and the aggregated statistics are ignored."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y9bvLRjuoovi"
      },
      "source": [
        "## Using moving averages at test time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "AkNvdMKjo1g1"
      },
      "outputs": [],
      "source": [
        "def hacky_np_initializer(array):\n",
        "  \"\"\"Allows us to initialize a tf variable with a numpy array.\"\"\"\n",
        "  def _init(shape, dtype, partition_info):\n",
        "    return tf.constant(np.asarray(array, dtype='float32'))\n",
        "  return _init\n",
        "\n",
        "tf.reset_default_graph()\n",
        "\n",
        "inputs = get_inputs()\n",
        "# We initialize the moving mean and variance to non-standard values\n",
        "# so we can see the effect of this setting\n",
        "bn = snt.BatchNorm(initializers={\n",
        "    \"moving_mean\": hacky_np_initializer([[10, 10]]), \n",
        "    \"moving_variance\": hacky_np_initializer([[1, 4]])\n",
        "})\n",
        "outputs = bn(inputs, is_training=False, test_local_stats=False)\n",
        "\n",
        "run_and_visualize(inputs, outputs, bn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4sZQFIFYMAzR"
      },
      "source": [
        "**Results**\n",
        "\n",
        "We have now manually initialized the moving statistics to the moments of the input distribution (mean 10 in both dimensions, variances 1 and 4). We can see that the inputs have been normalized according to these stored statistics rather than the statistics of each batch."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "default_view": {},
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "BatchNormDemo.ipynb",
      "provenance": [
        {
          "file_id": "0B0e1jSOBBuRndGdBMlVGYjROalU",
          "timestamp": 1498832161711
        }
      ],
      "version": "0.3.2",
      "views": {}
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
